# -*- coding: utf-8 -*-

import torch.nn as nn


class SegmentEmbedding(nn.Embedding):
    """Lookup table holding three learned vectors, with index 0 as padding.

    This is a thin specialization of ``nn.Embedding``: the number of rows is
    fixed at 3 and ``padding_idx`` is fixed at 0; only the vector width is
    configurable.

    NOTE(review): presumably index 0 marks padding positions and indices 1/2
    identify the two segments of a sentence pair -- confirm against the
    caller that produces the segment labels.
    """

    def __init__(self, embed_size=512):
        """Build the embedding table.

        :param embed_size: dimensionality of each embedding vector
        """
        table_rows = 3  # fixed number of distinct indices this table accepts
        pad_index = 0  # row reserved as the padding entry
        super().__init__(
            num_embeddings=table_rows,
            embedding_dim=embed_size,
            padding_idx=pad_index,
        )
